import torch
import numpy as np
import tqdm
import wandb

from torch import nn
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.multivariate_normal import MultivariateNormal
from bgflow.bg import sampling_efficiency


# Default device string: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def load_weights(bg, path, map_location=None):
    """Load saved weights into the model ``bg``.

    Two checkpoint layouts are handled, chosen by the file suffix:

    * ``*.pt``: a plain ``state_dict`` of ``bg`` itself, loaded with the
      default ``torch.load`` settings.
    * any other suffix: a pickled checkpoint loaded with
      ``weights_only=False``. If it has a ``"model_state_dict"`` entry, that
      entry is used. If the resulting dict contains ``"_prior._std"`` it is
      loaded into the whole ``bg``; otherwise it is loaded into ``bg._flow``.

    Parameters
    ----------
    bg : torch.nn.Module
        Model that receives the weights. Needs a ``_flow`` submodule for the
        flow-only layout.
    path : str, bytes or os.PathLike
        Checkpoint file.
    map_location : optional
        Forwarded to ``torch.load``. For example, ``"cpu"`` loads a GPU-saved
        checkpoint on a CPU-only machine. ``None`` (the default) keeps
        torch's default behaviour.
    """
    import os

    # Accept pathlib.Path / bytes paths as well as str; slicing a Path fails.
    path = os.fsdecode(path)
    from_them = not path.endswith(".pt")
    print(f"Loading path {path} I say it is from them {from_them}")
    if not from_them:
        bg.load_state_dict(torch.load(path, map_location=map_location))
        return

    # NOTE: weights_only=False unpickles arbitrary objects; only use on trusted files.
    saved_dict = torch.load(path, map_location=map_location, weights_only=False)
    if "model_state_dict" in saved_dict:
        saved_dict = saved_dict["model_state_dict"]
    if "_prior._std" in saved_dict:
        # The saved files are not all saved the same way ...
        bg.load_state_dict(saved_dict)
    else:
        bg._flow.load_state_dict(saved_dict)


def fm_train_step_ot(x1, prior, model, pot, sigma=0.01):
    """Compute one optimal-transport conditional flow-matching loss.

    Prior samples ``x0`` are paired with the data ``x1`` by drawing index
    pairs ``(i, j)`` with probability ``pi[i, j]``, where ``pi`` is the exact
    OT plan from ``pot.emd`` on the (normalised) squared Euclidean cost. The
    model's vector field is then regressed onto the straight-line velocity
    ``x1 - x0`` at a perturbed point on the interpolation path.

    Parameters
    ----------
    x1 : torch.Tensor
        Batch of data samples. Assumed 2-D ``[batch, dim]`` as required by
        ``torch.cdist`` here (TODO confirm with callers).
    prior :
        Object whose ``sample(n)`` returns ``n`` latent samples shaped like ``x1``.
    model :
        Object exposing ``model.flow._dynamics._dynamics._dynamics_function(t, x)``.
    pot :
        The POT module (or a compatible object) providing ``unif`` and ``emd``.
    sigma : float
        Scale applied to the extra ``prior.sample`` noise added to the interpolant.

    Returns
    -------
    torch.Tensor
        Scalar mean squared error between predicted and target velocities.
    """
    batchsize = x1.shape[0]
    # Create t on the data's own device/dtype (not a module-level default device),
    # so the step also works for CPU data when CUDA happens to be available.
    t = torch.rand(batchsize, 1, device=x1.device, dtype=x1.dtype)
    x0 = prior.sample(batchsize)

    # Exact OT plan between the two uniform empirical batches.
    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    cost = torch.cdist(x0, x1) ** 2
    max_cost = cost.max()
    if max_cost > 0:
        # Normalise for numerical stability; skipped when every pair coincides,
        # which would otherwise produce 0/0 = NaN costs.
        cost = cost / max_cost
    pi = pot.emd(a, b, cost.detach().cpu().numpy())

    # Sample index pairs (i, j) proportionally to the transport plan.
    p = pi.flatten()
    p = p / p.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batchsize)
    i, j = np.divmod(choices, pi.shape[1])
    x0 = x0[i]
    x1 = x1[j]

    # Regression target: constant velocity along the straight path x0 -> x1.
    mu_t = x0 * (1 - t) + x1 * t
    noise = prior.sample(batchsize)
    x = mu_t + sigma * noise
    ut = x1 - x0
    vt = model.flow._dynamics._dynamics._dynamics_function(t, x)
    return torch.mean((vt - ut) ** 2)


def train_loop(
    epochs,
    model,
    train_step,
    batches,
    optim,
    grad_acc_steps=1,
    grad_clipping=False,
    batches_per_epoch=None,
    save_callback=None,
    log_fn=None,
):
    """Generic training loop with gradient accumulation.

    Parameters
    ----------
    epochs : int
        Number of passes over ``batches()``.
    model : torch.nn.Module
        Model whose parameters are clipped when ``grad_clipping`` is set.
    train_step : callable
        ``train_step(x1)`` returns a scalar loss tensor for one batch.
    batches : callable
        Zero-argument callable returning a fresh iterable of batches each epoch.
    optim : torch.optim.Optimizer
        Optimizer stepped once per accumulation group.
    grad_acc_steps : int
        Number of batches whose gradients are accumulated before each step.
        Gradients are summed, not averaged, so the effective step size grows
        with this value. Values < 1 are treated as 1.
    grad_clipping : bool
        If True, clip the total gradient norm to 1.0 right before each step.
    batches_per_epoch : int, optional
        If given, also step when this many batches have been seen in the epoch.
    save_callback : callable, optional
        Called as ``save_callback(epoch)`` after every epoch.
    log_fn : callable, optional
        Receives ``{"epoch": ..., "train-loss": ...}`` after every step.
        Defaults to ``wandb.log``.
    """
    if log_fn is None:
        log_fn = wandb.log
    acc_steps = max(grad_acc_steps, 1)
    max_norm = 1.0  # maximum norm of the gradients

    def _update(epoch, loss_sum, n_losses):
        # Clip once, on the fully accumulated gradient, right before stepping;
        # clipping after every backward would rescale partial sums unevenly.
        if grad_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optim.step()
        optim.zero_grad()
        # Average over the losses actually accumulated (may be < acc_steps).
        log_fn({"epoch": epoch, "train-loss": loss_sum / n_losses})

    optim.zero_grad()
    for epoch in range(epochs):
        counter = 0
        pending = 0
        loss_sum = 0.0
        for x1 in batches():
            loss = train_step(x1)
            loss.backward()
            counter += 1
            pending += 1
            loss_sum += loss.item()

            if counter % acc_steps == 0 or counter == batches_per_epoch:
                _update(epoch, loss_sum, pending)
                pending = 0
                loss_sum = 0.0
        if pending:
            # Flush a trailing partial accumulation so its gradients and loss
            # do not leak into the next epoch.
            _update(epoch, loss_sum, pending)
        if save_callback is not None:
            save_callback(epoch)


def get_stats_bg(bg, n_samples=200, n_sample_batches=50):
    """Draw samples from ``bg`` in batches and compute its sampling efficiency.

    Side effect: overwrites ``bg.flow``'s integrator settings in place
    (``atol=1e-5``, ``rtol=1e-3``, checkpoints disabled, extra kwargs cleared)
    and does not restore them.

    Parameters
    ----------
    bg :
        Generator exposing ``sample(n, with_latent=True, with_dlogp=True)`` and
        ``log_weights_given_latent(samples, latent, dlogp, normalize=False)``.
    n_samples : int
        Samples drawn per batch.
    n_sample_batches : int
        Number of batches.

    Returns
    -------
    log_w_np : np.ndarray
        Flat float64 array of the unnormalised log importance weights from all
        batches.
    ess : float
        ``sampling_efficiency`` of those log weights.
    """
    bg.flow._integrator_atol = 1e-5
    bg.flow._integrator_rtol = 1e-3
    bg.flow._use_checkpoints = False
    bg.flow._kwargs = {}

    # Collect per-batch chunks and concatenate once; np.append in a loop is quadratic.
    chunks = []
    with torch.no_grad():
        for _ in tqdm.tqdm(range(n_sample_batches)):
            samples, latent, dlogp = bg.sample(
                n_samples, with_latent=True, with_dlogp=True
            )
            log_weights = bg.log_weights_given_latent(
                samples, latent, dlogp, normalize=False
            )
            chunks.append(log_weights.detach().cpu().numpy().reshape(-1))

    # float64 to match the dtype produced by appending onto np.empty(0).
    log_w_np = np.concatenate(chunks).astype(np.float64) if chunks else np.empty(0)
    ess = sampling_efficiency(torch.from_numpy(log_w_np)).item()
    return log_w_np, ess


class HutchinsonEstimatorDifferentiable(torch.nn.Module):
    """
    Estimation of the divergence of a dynamics function with the Hutchinson Estimator [1].

    The vector-Jacobian product is built with ``create_graph=True``, so the
    returned divergence estimate can itself be differentiated.

    [1] A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines, Hutchinson
    """

    def __init__(self, rademacher=True):
        super().__init__()
        self._rademacher = rademacher
        self._reset_noise = True
        self._noise = None

    def reset_noise(self, reset_noise=True):
        """
        Set whether a fresh noise vector is drawn on every call.

        With ``reset_noise=False`` the previously drawn noise is reused.
        """

        self._reset_noise = reset_noise

    def _sample_noise(self, xs):
        """Draw a probe vector shaped like ``xs`` on its device and dtype."""
        if self._rademacher:
            # Entries are +1/-1 with equal probability, generated directly on xs' device.
            return torch.randint(low=0, high=2, size=xs.shape, device=xs.device).to(xs.dtype) * 2 - 1
        return torch.randn_like(xs)

    def forward(self, dynamics, t, xs):
        """
        Computes the change of the system `dxs` due to a time independent dynamics function.
        Furthermore, also estimates the change of log density, which is equal to the divergence of the change `dxs`,
        with the Hutchinson Estimator.
        This is done with either Rademacher or Gaussian noise.

        Parameters
        ----------
        dynamics : callable
            ``dynamics(t, xs)`` returning the 2-D state update ``[n_batch, system_dim]``.
        t : PyTorch tensor
            The current time
        xs : PyTorch tensor
            The current configuration of the system. ``requires_grad`` is
            switched on in place.

        Returns
        -------
        dxs, -divergence: PyTorch tensors
            The state update ``dxs`` and the negative divergence estimate of
            shape ``[n_batch, 1]``.
        """

        with torch.set_grad_enabled(True):
            xs.requires_grad_(True)
            dxs = dynamics(t, xs)

            assert dxs.dim() == 2, "`dxs` must have shape [n_batch, system_dim]"
            system_dim = dxs.shape[-1]

            # Also draw noise when none exists yet (e.g. reset disabled before the first call).
            if self._reset_noise or getattr(self, "_noise", None) is None:
                self._noise = self._sample_noise(xs)

            # noise^T J noise, with J the Jacobian of dxs w.r.t. xs.
            noise_ddxs = torch.autograd.grad(dxs, xs, self._noise, create_graph=True)[0]
            divergence = torch.sum(
                (noise_ddxs * self._noise).view(-1, system_dim), 1, keepdim=True
            )

        return dxs, -divergence


class BruteForceEstimatorFastDifferentiable(torch.nn.Module):
    """
    Exact bruteforce estimation of the divergence of a dynamics function.

    Each diagonal Jacobian entry is obtained with its own ``autograd.grad`` call
    (``create_graph=True``), so the result stays differentiable.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dynamics, t, xs):
        """
        Computes the change of the system `dxs` due to a time independent dynamics function.
        Furthermore, also computes the exact change of log density
        which is equal to the divergence of the change `dxs`.

        Parameters
        ----------
        dynamics : callable
            ``dynamics(t, xs)`` returning the 2-D state update ``[n_batch, system_dim]``.
        t : PyTorch tensor
            The current time
        xs : PyTorch tensor
            The current configuration of the system, ``[n_batch, system_dim]``.

        Returns
        -------
        dxs, -divergence: PyTorch tensors
            The state update ``dxs`` and the negative exact divergence of
            shape ``[n_batch, 1]``.
        """

        with torch.set_grad_enabled(True):
            # One column slice per coordinate so d(dx_i)/d(x_i) can be taken separately.
            x = [xs[:, [i]] for i in range(xs.size(1))]
            for xi in x:
                if not xi.requires_grad:
                    # Slices of an input that does not track gradients are plain
                    # leaves; enable grad so autograd.grad below can use them.
                    xi.requires_grad_(True)

            dxs = dynamics(t, torch.cat(x, dim=1))

            assert len(dxs.shape) == 2, "`dxs` must have shape [n_batch, system_dim]"
            divergence = 0
            for i, xi in enumerate(x):
                divergence += torch.autograd.grad(
                    dxs[:, [i]],
                    xi,
                    torch.ones_like(dxs[:, [i]]),
                    retain_graph=True,
                    create_graph=True,
                )[0]

        return dxs, -divergence.view(-1, 1)


class AugmentedAdjointDyn(torch.nn.Module):
    """Black-box dynamics augmented with the evolution of a carried gradient.

    The state tensor's last axis has size 2: ``[..., 0]`` is the system state
    ``xs`` and ``[..., 1]`` is a gradient ``dlogQdx`` transported alongside it.
    The wrapped ``bb_dyn(t, xs)`` must return ``(dxs, divergence)``, e.g. one of
    the divergence estimators in this module bound to a dynamics function.
    """

    def __init__(self, bb_dyn, compute_divergence=True):
        super().__init__()
        self._bb_dyn = bb_dyn
        self._compute_divergence = compute_divergence

    def forward(self, t, xsdlogQdx):
        """
        Compute the joint update of the state and the carried gradient.

        Parameters
        ----------
        t : PyTorch tensor
            The current time
        xsdlogQdx : PyTorch tensor
            State stacked with the carried gradient along the last axis
            (``[..., 0]`` = ``xs``, ``[..., 1]`` = ``dlogQdx``). Presumably
            ``[n_batch, dim, 2]``, since the output is stacked on ``dim=2``;
            TODO confirm.

        Returns
        -------
        (update, divergence): Tuple of PyTorch tensors
            ``update`` stacks ``dxs`` with ``ddlogQdx`` on ``dim=2``, where
            ``ddlogQdx = -(d dxs/d xs)^T dlogQdx + d(sum divergence)/d xs``.
            ``divergence`` is the second output of ``bb_dyn`` unchanged.
            When ``compute_divergence`` is False, the result of
            ``bb_dyn(t, xsdlogQdx)`` is returned as-is.
        """
        if not self._compute_divergence:
            # No gradient augmentation: delegate straight to the wrapped
            # dynamics (the state is the single positional argument).
            return self._bb_dyn(t, xsdlogQdx)

        with torch.set_grad_enabled(True):
            xs = xsdlogQdx[..., 0]
            dlogQdx = xsdlogQdx[..., 1]
            xs.requires_grad_(True)
            dxs, divergence = self._bb_dyn(t, xs)

            # Vector-Jacobian product with -dlogQdx plus the gradient of the divergence term.
            ddlogQdx = torch.autograd.grad(
                (dxs * (-dlogQdx.detach())).sum() + divergence.sum(),
                xs,
                retain_graph=True,
            )[0]
        return torch.stack([dxs, ddlogQdx], dim=2), divergence


def clip_norm(grads, max_norm=1.0, norm_type=2.0):
    """
    Rescale ``grads`` so that its total ``norm_type``-norm is at most ``max_norm``.

    Mirrors ``torch.nn.utils.clip_grad_norm_``, but operates on (and returns) a
    tensor instead of modifying parameter gradients. The norm is taken per row
    over all trailing dimensions, then across rows; the whole tensor is scaled
    by ``min(1, max_norm / (total_norm + 1e-6))``.
    """
    row_norms = torch.norm(grads.reshape((grads.shape[0], -1)), norm_type, dim=1)
    total = torch.norm(row_norms, norm_type)
    scale = torch.clamp(max_norm / (total + 1e-6), max=1.0)
    return grads * scale


def path_gradient(
    x,
    prior,
    target,
    flow,
    x_includes_grads=False,
    force_clipping=False,
):
    """Surrogate loss whose value is the mean log weight and whose gradient is a path gradient.

    Target samples are pushed through the inverse flow together with a
    gradient ``dlogQdx``. The returned scalar equals ``mean(logp - log_q)``
    in value. Its gradient flows only through
    ``mean(eps * (dlogQdeps - dlogPdeps).detach())``.

    Parameters
    ----------
    x : torch.Tensor or tuple
        Samples. If ``x_includes_grads`` is set, a pair ``(x, dlogQdx)``
        whose second entry is reshaped to ``x.shape``.
    prior, target :
        Objects with an ``energy`` method (log density = ``-energy``).
    flow : callable
        Called as ``flow(stacked, inverse=True)`` on ``[..., 2]``-stacked
        (state, gradient) and returning ``(stacked_out, dlogp)``.
    x_includes_grads : bool
        Use the provided gradients instead of ``d(-target.energy)/dx``.
        Samples whose provided gradient has any non-finite entry are dropped.
    force_clipping : bool
        Clip the per-batch gradient term with ``clip_norm``.

    Returns
    -------
    torch.Tensor
        Scalar surrogate loss.
    """
    if x_includes_grads:
        x, dlogQdx = x
        dlogQdx = dlogQdx.reshape(x.shape)
        # Keep only samples whose gradient is entirely finite; a single NaN/inf
        # component would otherwise poison the whole batch mean.
        mask = torch.isfinite(dlogQdx).flatten(start_dim=1).all(dim=1)
        x = x[mask]
        dlogQdx = dlogQdx[mask]
    x1 = x.detach().clone().requires_grad_()
    logp = -target.energy(x1)
    if not x_includes_grads:
        dlogQdx = torch.autograd.grad(logp.sum(), x1)[0]
    epsdlogQdeps, dlogp = flow(torch.stack((x1, dlogQdx), dim=2), inverse=True)
    eps = epsdlogQdeps[..., 0]
    dlogQdeps = epsdlogQdeps[..., 1]

    logp_eps = -prior.energy(eps)
    dlogPdeps = torch.autograd.grad(logp_eps.sum(), eps)[0]

    log_q = logp_eps + dlogp
    logw = logp - log_q
    dLdeps = dlogQdeps - dlogPdeps
    if force_clipping:
        dLdeps = clip_norm(dLdeps)
    # Stop-gradient trick: contributes zero to the value but carries the path gradient.
    gradterm = (eps * dLdeps.detach()).mean()
    return logw.mean().detach() + gradterm - gradterm.detach()
